"""
大语言模型用于氢能企业分类任务
"""

from datasets import load_dataset
from settings import PROMPT_CLS

from utils import glm4_vllm


def trans_data(item):
    """Attach a classification prompt to one dataset record.

    Reads the company name, the business scope and the three registered
    industry-chain levels from ``item``, renders them into one description
    string, and stores ``PROMPT_CLS`` formatted with that description under
    ``item["prompt"]``.

    Args:
        item: A dict-like record supporting ``.get`` and item assignment;
            it is mutated in place.

    Returns:
        The same ``item`` object, now carrying a ``"prompt"`` entry.

    NOTE(review): a field that ``.get`` cannot find comes back as ``None``
    and is rendered into the prompt as the literal text "None".
    """
    company = item.get("企业名称")
    scope = item.get("经营范围")
    # Industry-chain classification, from the broadest level to the narrowest.
    top, middle, bottom = (
        item.get("大类名称"),
        item.get("中类名称"),
        item.get("小类名称"),
    )
    description = (
        f"{company}：经营范围包括{scope}；"
        f"登记产业链信息为:{top} {middle} {bottom}；"
    )
    item["prompt"] = PROMPT_CLS.format(industry_info=description)
    return item


# Load the enterprise list CSV as a single split and add a "prompt" column to
# every record. The data path is relative, so it resolves against the current
# working directory.
dataset = load_dataset(
    "csv",
    data_files="../hydrogen/data_clean/全国5万家氢能企业名单_utf8.csv",
    split="train",
).map(trans_data)

# One single-turn conversation (a lone user message) per record.
prompts = [[{"role": "user", "content": row["prompt"]}] for row in dataset]

# The second argument is presumably where glm4_vllm stores its results;
# confirm against utils.glm4_vllm.
glm4_vllm(prompts, "objs/5w_cls_240902.pkl")
# Run in the background with:
# nohup python vllm_cls.py > vllm_cls.log 2>&1 &
